﻿// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
namespace Microsoft.Quantum.QsCompiler.SyntaxProcessing

open System
open System.Collections.Immutable
open Microsoft.Quantum.QsCompiler
open Microsoft.Quantum.QsCompiler.DataTypes
open Microsoft.Quantum.QsCompiler.SyntaxTokens

/// A single occurrence of a symbol, type or literal found while walking a syntax fragment.
type SymbolOccurrence =
    /// A symbol introduced by a declaration (e.g. a binding, a parameter, a callable or type name).
    | Declaration of QsSymbol
    /// A reference to a type.
    | UsedType of QsType
    /// A reference to a name, e.g. an identifier in an expression, an opened namespace or an attribute.
    | UsedVariable of QsSymbol
    /// A literal expression (also used for missing expressions and for strings without interpolated parts).
    | UsedLiteral of QsExpression

module SymbolOccurrence =
    /// Expands a symbol into its leaf symbols: symbol tuples are unpacked recursively, left to right,
    /// and every other symbol is returned unchanged as a one-element list.
    let rec flattenSymbol symbol =
        match symbol.Symbol with
        | SymbolTuple children -> [ for child in children do yield! flattenSymbol child ]
        | _ -> [ symbol ]

    /// Produces one declaration occurrence per leaf symbol of a (possibly tuple-shaped) symbol.
    let symbolDeclarations = List.map Declaration << flattenSymbol

    /// Collects the type occurrences within a type.
    /// Atomic types, user-defined types and type parameters each yield a single used-type occurrence;
    /// array, tuple, operation and function types are searched recursively (inputs before outputs);
    /// missing and invalid types yield nothing.
    let rec inType type_ =
        match type_.Type with
        | ArrayType itemType -> inType itemType
        | TupleType items -> [ for item in items do yield! inType item ]
        | QsTypeKind.Operation ((input, output), _)
        | QsTypeKind.Function (input, output) -> List.append (inType input) (inType output)
        | MissingType
        | InvalidType -> []
        | UnitType
        | Int
        | BigInt
        | Double
        | Bool
        | String
        | Qubit
        | Result
        | Pauli
        | Range
        | UserDefinedType _
        | TypeParameter _ -> [ UsedType type_ ]

    /// Collects the symbol occurrences within an expression, preserving the order of its sub-expressions.
    /// Literals and missing expressions yield a used-literal occurrence; identifiers yield a used-variable
    /// occurrence followed by the occurrences in their type arguments; lambdas declare each validly named
    /// parameter before the occurrences in their body; invalid expressions yield nothing.
    let rec inExpression expression =
        match expression.Expression with
        | UnitValue
        | IntLiteral _
        | BigIntLiteral _
        | DoubleLiteral _
        | BoolLiteral _
        | ResultLiteral _
        | PauliLiteral _
        | MissingExpr -> [ UsedLiteral expression ]
        | StringLiteral (_, interpolated) ->
            // TODO: Remove special case for un-interpolated strings once overlapping occurrences are handled.
            if interpolated.IsEmpty then
                [ UsedLiteral expression ]
            else
                [ for part in interpolated do yield! inExpression part ]
        | Identifier (name, typeArgs) ->
            let explicitTypeArgs = QsNullable.defaultValue ImmutableArray.Empty typeArgs
            UsedVariable name :: [ for typeArg in explicitTypeArgs do yield! inType typeArg ]
        | Lambda lambda ->
            // Parameters with an invalid name contribute no declaration.
            let parameterDeclaration (parameter: SyntaxTree.LocalVariableDeclaration<SyntaxTree.QsLocalSymbol, _>) =
                match parameter.VariableName with
                | SyntaxTree.QsLocalSymbol.ValidName name ->
                    [ Declaration { Symbol = Symbol name; Range = Value parameter.Range } ]
                | SyntaxTree.QsLocalSymbol.InvalidName -> []

            [ for parameter in lambda.ArgumentTuple.Items do
                  yield! parameterDeclaration parameter
              yield! inExpression lambda.Body ]
        // TODO: Handle named item accessor.
        | NamedItem (inner, _)
        | NEG inner
        | NOT inner
        | BNOT inner
        | UnwrapApplication inner
        | AdjointApplication inner
        | ControlledApplication inner -> inExpression inner
        | ArrayItem (lhs, rhs)
        | ADD (lhs, rhs)
        | SUB (lhs, rhs)
        | MUL (lhs, rhs)
        | DIV (lhs, rhs)
        | MOD (lhs, rhs)
        | POW (lhs, rhs)
        | EQ (lhs, rhs)
        | NEQ (lhs, rhs)
        | LT (lhs, rhs)
        | LTE (lhs, rhs)
        | GT (lhs, rhs)
        | GTE (lhs, rhs)
        | AND (lhs, rhs)
        | OR (lhs, rhs)
        | BOR (lhs, rhs)
        | BAND (lhs, rhs)
        | BXOR (lhs, rhs)
        | LSHIFT (lhs, rhs)
        | RSHIFT (lhs, rhs)
        | SizedArray (lhs, rhs)
        | RangeLiteral (lhs, rhs)
        | CallLikeExpression (lhs, rhs) -> List.append (inExpression lhs) (inExpression rhs)
        | CONDITIONAL (first, second, third)
        | CopyAndUpdate (first, second, third) ->
            [ yield! inExpression first
              yield! inExpression second
              yield! inExpression third ]
        | ValueTuple items
        | ValueArray items -> [ for item in items do yield! inExpression item ]
        | NewArray (itemType, length) -> List.append (inType itemType) (inExpression length)
        | InvalidExpr -> []

    /// Collects the occurrences within a qubit initializer: register sizes are searched as expressions,
    /// tuple initializers are searched recursively, and single-qubit or invalid initializers yield nothing.
    let rec inInitializer initializer =
        match initializer.Initializer with
        | SingleQubitAllocation
        | InvalidInitializer -> []
        | QubitRegisterAllocation size -> inExpression size
        | QubitTupleAllocation items -> [ for item in items do yield! inInitializer item ]

    /// Collects the occurrences within a specialization generator: a user-defined implementation declares
    /// the symbols of its argument; every other generator kind yields nothing.
    let inGenerator generator =
        match generator.Generator with
        | UserDefinedImplementation argument -> symbolDeclarations argument
        | Intrinsic
        | AutoGenerated
        | FunctorGenerationDirective _ -> []

    /// Collects the occurrences within a tuple of named, typed items: each item declares its name's
    /// symbols followed by the occurrences in its type; nested tuples are searched recursively.
    let rec inNamedTuple tuple =
        match tuple with
        | QsTupleItem (name, itemType) -> List.append (symbolDeclarations name) (inType itemType)
        | QsTuple items -> [ for item in items do yield! inNamedTuple item ]

    /// Collects the occurrences in a callable declaration, in this order: the callable's name, its type
    /// parameters, its argument tuple, and its return type.
    let inCallable (callable: CallableDeclaration) =
        let signature = callable.Signature

        [ yield Declaration callable.Name
          for typeParam in signature.TypeParameters do
              yield Declaration typeParam
          yield! inNamedTuple signature.Argument
          yield! inType signature.ReturnType ]

    /// Collects every symbol occurrence in a single syntax fragment.
    /// Exposed to C# as `InFragment`.
    [<CompiledName "InFragment">]
    let inFragment fragment =
        match fragment with
        | NamespaceDeclaration name -> [ Declaration name ]
        | OpenDirective (openedNamespace, alias) ->
            let aliasDeclaration =
                match alias with
                | Value aliasSymbol -> [ Declaration aliasSymbol ]
                | Null -> []

            UsedVariable openedNamespace :: aliasDeclaration
        | DeclarationAttribute (attribute, argument) -> UsedVariable attribute :: inExpression argument
        | OperationDeclaration callable
        | FunctionDeclaration callable -> inCallable callable
        | BodyDeclaration generator
        | AdjointDeclaration generator
        | ControlledDeclaration generator
        | ControlledAdjointDeclaration generator -> inGenerator generator
        | ImmutableBinding (binding, value)
        | MutableBinding (binding, value)
        | ForLoopIntro (binding, value) -> List.append (symbolDeclarations binding) (inExpression value)
        | UsingBlockIntro (binding, allocation)
        | BorrowingBlockIntro (binding, allocation) ->
            List.append (symbolDeclarations binding) (inInitializer allocation)
        | TypeDefinition typeDef -> Declaration typeDef.Name :: inNamedTuple typeDef.UnderlyingType
        | ExpressionStatement expr
        | ReturnStatement expr
        | FailStatement expr
        | IfClause expr
        | ElifClause expr
        | WhileLoopIntro expr
        | UntilSuccess (expr, _) -> inExpression expr
        | ValueUpdate (target, value) -> List.append (inExpression target) (inExpression value)
        | ElseClause
        | RepeatIntro
        | WithinBlockIntro
        | ApplyBlockIntro
        | InvalidFragment -> []

// C# interoperability: members that let C# code inspect a SymbolOccurrence without F# pattern matching.
type SymbolOccurrence with
    /// Invokes the callback that corresponds to this occurrence's case with the case's payload,
    /// and returns that callback's result.
    member occurrence.Match
        (
            declaration: Func<QsSymbol, 'a>,
            usedType: Func<QsType, 'a>,
            usedVariable: Func<QsSymbol, 'a>,
            usedLiteral: Func<QsExpression, 'a>
        ) =
        match occurrence with
        | UsedLiteral literal -> usedLiteral.Invoke literal
        | UsedVariable variable -> usedVariable.Invoke variable
        | UsedType ty -> usedType.Invoke ty
        | Declaration declared -> declaration.Invoke declared

    /// If this occurrence is a declaration, writes the declared symbol to `symbol` and returns true;
    /// otherwise returns false without writing to `symbol`.
    member occurrence.TryGetDeclaration(symbol: QsSymbol outref) =
        match occurrence with
        | Declaration declared ->
            symbol <- declared
            true
        | UsedType _
        | UsedVariable _
        | UsedLiteral _ -> false

    /// If this occurrence is a used type, writes that type to `type` and returns true;
    /// otherwise returns false without writing to `type`.
    member occurrence.TryGetUsedType(``type``: QsType outref) =
        match occurrence with
        | UsedType ty ->
            ``type`` <- ty
            true
        | Declaration _
        | UsedVariable _
        | UsedLiteral _ -> false

    /// If this occurrence is a used variable, writes its symbol to `symbol` and returns true;
    /// otherwise returns false without writing to `symbol`.
    member occurrence.TryGetUsedVariable(symbol: QsSymbol outref) =
        match occurrence with
        | UsedVariable variable ->
            symbol <- variable
            true
        | Declaration _
        | UsedType _
        | UsedLiteral _ -> false

    /// If this occurrence is a used literal, writes the literal expression to `expression` and returns true;
    /// otherwise returns false without writing to `expression`.
    member occurrence.TryGetUsedLiteral(expression: QsExpression outref) =
        match occurrence with
        | UsedLiteral literal ->
            expression <- literal
            true
        | Declaration _
        | UsedType _
        | UsedVariable _ -> false
